Skip to content

make sort32 fast#327

Open
39ali wants to merge 4 commits into
sparkjsdev:mainfrom
39ali:sort32-fast
Open

make sort32 fast#327
39ali wants to merge 4 commits into
sparkjsdev:mainfrom
39ali:sort32-fast

Conversation

@39ali
Copy link
Copy Markdown

@39ali 39ali commented Apr 29, 2026

try to improve the performance of sort32, on avg it's 30-40% faster .

things that changed :

  • pass 2 no longer re-reads keys[] , scratch stores a packed u64 of (inverted_key << 32 | original_index). pass 2 reads the high 16 bits directly from scratch with kv >> 48 making it a sequential scan

  • histogram and scatter are now branchless to help llvm vectorize the loop

  • manually unrolled histogram and both scatter passes to 8-wide

Comment thread rust/spark-worker-rs/src/sort.rs Outdated
/// Two‑pass radix sort (base 2¹⁶) of 32‑bit float bit‑patterns,
/// descending order (largest keys first). Mirrors the JS `sort32Splats`.
#[inline(always)]
unsafe fn prefix_sum_exclusive(buckets: &mut [u32]) -> u32 {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a specific reason this is marked unsafe? It compiles just fine without.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i had many experiments with simd, which didn't make it marginally faster so i removed it for simplicity sake but forgot to remove the unsafe, will clean it up

@mrxz
Copy link
Copy Markdown
Collaborator

mrxz commented Apr 29, 2026

Awesome work, gave it a try and can confirm that it improves sorting performance. In my limited testing I saw ~20% reduction in sorting time (~25% faster).

manually unrolled histogram and both scatter passes to 8-wide

Without this change the performance gain seems to be roughly the same, or at least I didn't observe any significant difference. The majority of the benefit seems to come from making it branchless.

@39ali
Copy link
Copy Markdown
Author

39ali commented Apr 29, 2026

@mrxz i squeezed a bit more performance ~<=1ms by removing more branches from hot loops, and what you noticed seems about right, it will differ from one wasm engine to another, and arch to another(specially cache sizes and arch) so it's hard to give a solid number but it'll still be a pump in performance

@39ali 39ali force-pushed the sort32-fast branch 2 times, most recently from 8c1efb4 to 77253d1 Compare April 29, 2026 17:42
@dmarcos
Copy link
Copy Markdown
Contributor

dmarcos commented May 14, 2026

can you remove the changes in the dist directory?

@asundqui
Copy link
Copy Markdown
Contributor

@39ali great work! This looks like a cool win indeed, thanks for the work. Could remove the build in the dist/ folder from your branch, then we can merge it in?

@mrxz
Copy link
Copy Markdown
Collaborator

mrxz commented May 15, 2026

Could the macros used for unrolling the loops also be used for the body of remainder loops? Both should be identical, so if we could avoid the duplication we avoid the risk of it ever getting out of sync.

@dmarcos
Copy link
Copy Markdown
Contributor

dmarcos commented May 15, 2026

@39ali any chance for you to implement @mrxz suggestions? Thanks so much for the contribution

@39ali
Copy link
Copy Markdown
Author

39ali commented May 26, 2026

I'll implement the changes

@39ali
Copy link
Copy Markdown
Author

39ali commented Jun 2, 2026

@mrxz @dmarcos @asundqui done !

…the essential optimizations. Removed second branchless optimization. Added comments on why `unsafe` accesses are okay.
@asundqui
Copy link
Copy Markdown
Contributor

asundqui commented Jun 6, 2026

@39ali really great work here! I've done some benchmarking on your method, and I'm actually getting 2x - 4x speedups in sorting from this. I'm truly shocked that this was possible! This will have a great impact on Spark's sorting performance. On a 10M splat scene on my M3 it goes from 250 ms to 60 ms or so. It's possible that the speedup is not as great on other environments, such as @mrxz was reporting 25% speedups on his system.

I went through and carefully separated the optimizations and measured them:

  • I found that approx 65% of the gain could be explained by storing (key, index) as a packed u64 in scratch, which turned the next loop from a random gather into a sequential read. So much better cache performance as a result.

  • The next 20% came from doing unsafe { unchecked... } array accesses where the compiler couldn't be sure that it would always be in bounds, so it has to check it every iteration. I went through the logic and it looks like it should always be in bounds.

  • The final 15% or so came from unrolling the loops. I had thought the unrolling couldn't possibly help because of branch prediction + cpu instruction reordering but it all helps!

It did seem like there was one error though: the second branchless loop seems problematic... I think writing to the array and only advancing the pointer if it's "valid" could overwrite things. So I removed it. I don't think it does very much for the performance anyway.

Finally I reverted some unnecessary changes to make it closer to @mrxz 's original formulation. I think we should merge this in @dmarcos , @mrxz ! WDYT? This should really help with #225 .

Interestingly, because the sorting is so much faster, it sort of exposes the next bottleneck more: uploading the ordering frequently to the GPU can cause stuttering sometimes when the counts get large. Now this happens more often!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants